{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从一份良好的数据集中学习，训练过程中完全不更新数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20000,\n",
       " (tensor([0.9019, 0.4320, 0.8704]),\n",
       "  tensor([-1.5736]),\n",
       "  tensor([0.9651]),\n",
       "  tensor([0.8817, 0.4719, 0.8943]),\n",
       "  tensor([0])))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#封装数据集\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "\n",
    "    def __init__(self):\n",
    "        import numpy as np\n",
    "        data = np.loadtxt('离线学习数据.txt')\n",
    "        self.state = torch.FloatTensor(data[:, :3]).reshape(-1, 3)\n",
    "        self.action = torch.FloatTensor(data[:, 3]).reshape(-1, 1)\n",
    "        self.reward = torch.FloatTensor(data[:, 4]).reshape(-1, 1)\n",
    "        self.next_state = torch.FloatTensor(data[:, 5:8]).reshape(-1, 3)\n",
    "        self.over = torch.LongTensor(data[:, 8]).reshape(-1, 1)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.state)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.state[i], self.action[i], self.reward[i], self.next_state[\n",
    "            i], self.over[i]\n",
    "\n",
    "\n",
    "dataset = Dataset()\n",
    "\n",
    "len(dataset), dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "312"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=64,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<__main__.CQL at 0x7f93e8b8a710>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from sac import SAC\n",
    "\n",
    "\n",
    "#定义CQL算法\n",
    "class CQL(SAC):\n",
    "\n",
    "    def get_loss_cql(self, state, next_state, value):\n",
    "        #把state,next_state复制5遍\n",
    "        state = state.unsqueeze(dim=1).repeat(1, 5, 1).reshape(-1, 3)\n",
    "        next_state = next_state.unsqueeze(1).repeat(1, 5, 1).reshape(-1, 3)\n",
    "\n",
    "        #计算动作和熵\n",
    "        rand_action = torch.empty([len(state), 1]).uniform_(-1, 1)\n",
    "        curr_action, _ = self.get_action_entropy(state)\n",
    "        next_action, _ = self.get_action_entropy(next_state)\n",
    "\n",
    "        #计算三份动作分别的value\n",
    "        value_rand = self.model_value(torch.cat([state, rand_action], dim=1))\n",
    "        value_curr = self.model_value(torch.cat([state, curr_action], dim=1))\n",
    "        value_next = self.model_value(torch.cat([state, next_action], dim=1))\n",
    "\n",
    "        #拼合三份value\n",
    "        value_cat = torch.cat([value_rand, value_curr, value_next], dim=1)\n",
    "        loss_cat = value_cat.exp().sum(dim=1) + 1e-8\n",
    "        loss_cat = loss_cat.log().mean()\n",
    "\n",
    "        #在value loss上增加上这一部分\n",
    "        return 5.0 * (loss_cat - value.mean())\n",
    "\n",
    "\n",
    "cql = CQL()\n",
    "\n",
    "cql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAddUlEQVR4nO3df3BTdb7/8VfSNKG0TUpbmtCllc7lZ5cfasESvSP7lUp1qytS77BcBruKOmLhgni5a13Eu17H8sUZXXUVdsZZ4c4d7d66FgQBrQWKDKFAoVIKVFxZW4Gk/DBJW2jaJu/7B/as0fIjbdpPA6/HTGbsOZ+k76B5epKTBJ2ICIiI+ple9QBEdGNifIhICcaHiJRgfIhICcaHiJRgfIhICcaHiJRgfIhICcaHiJRgfIhICWXxeeuttzBixAgMGjQI2dnZ2Lt3r6pRiEgBJfH5y1/+gqVLl+KFF17AgQMHMGnSJOTm5qKpqUnFOESkgE7FB0uzs7MxZcoU/PGPfwQABAIBpKWlYdGiRXj22Wevev1AIIBTp04hPj4eOp2ur8clomskImhubkZqair0+isf2xj6aSZNe3s7qqurUVRUpG3T6/XIycmBw+Ho9jo+nw8+n0/7+eTJk8jMzOzzWYmoZxobGzF8+PArrun3+Jw9exZ+vx9WqzVou9VqxbFjx7q9TnFxMX7/+9//ZHtjYyPMZnOfzElEofN6vUhLS0N8fPxV1/Z7fHqiqKgIS5cu1X7uuoNms5nxIRqAruXlkH6PT3JyMqKiouByuYK2u1wu2Gy2bq9jMplgMpn6Yzwi6if9frbLaDQiKysLFRUV2rZAIICKigrY7fb+HoeIFFHytGvp0qUoKCjA5MmTcdttt+EPf/gDWltb8cgjj6gYh4gUUBKf2bNn48yZM1ixYgWcTiduvvlmbN269ScvQhPR9UvJ+3x6y+v1wmKxwOPx8AVnogEklMcmP9tFREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREqEHJ+dO3fi/vvvR2pqKnQ6HdavXx+0X0SwYsUKDBs2DDExMcjJycHx48eD1pw/fx5z586F2WxGQkIC5s+fj5aWll7dESKKLCHHp7W1FZMmTcJbb73V7f5Vq1bhjTfewJo1a1BVVYXY2Fjk5uaira1NWzN37lzU1dWhvLwcmzZtws6dO/HEE0/0/F4QUeSRXgAgZWVl2s+BQEBsNpu88sor2ja32y0mk0nef/99ERE5cuSIAJB9+/Zpa7Zs2SI6nU5Onjx5Tb/X4/EIAPF4PL0Zn4jCLJTHZlhf8zlx4gScTidycnK0bRaLBdnZ2XA4HAAAh8OBhIQETJ48WVuTk5MDvV6Pqqqqbm/X5/PB6/UGXYgosoU1Pk6nEwBgtVqDtlutVm2f0+lESkpK0H6DwYDExERtzY8VFxfDYrFol7S0tHCOTUQKRMTZrqKiIng8Hu3S2NioeiQi6qWwxsdmswEAXC5X0HaXy6Xts9lsaGpqCtrf2dmJ8+fPa2t+zGQywWw2B12IKLKFNT4ZGRmw2WyoqKjQtnm9XlRVVcFutwMA7HY73G43qqurtTXbtm1DIBBAdnZ2OMchogHMEOoVWlpa8NVXX2k/nzhxAjU1NUhMTER6ejqWLFmCl156CaNGjUJGRgaef/55pKamYubMmQCAcePG4Z577sHjjz+ONWvWoKOjAwsXLsSvf/1rpKamhu2OEdEAF+qptO3btwuAn1wKCgpE5NLp9ueff16sVquYTCaZPn261NfXB93GuXPnZM6cORIXFydms1keeeQRaW5uvuYZeKqdaGAK5bGpExFR2L4e8Xq9sFgs8Hg8fP2HaAAJ5bEZEWe7iOj6w/gQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpYVA9APWeiED8fsDvv7QhKgq6qCjodDq1gxFdAeMT4QLt7fDW1OD89u248M03AIDBGRmIy8xE/MSJGPSzn0Gn5wEuDTyMTwQL+Hw4XVqKpo8+QqCtTdvu+/ZbfLdrFwwJCbDOnImk6dNhiI/nkRANKPxfYoSSQABntm6Fq6wsKDz/WCDo/O47nFy3Dl+98AK++/xz+LtbR6QI4xOh2s+cgWvDBkhHx5UXBgK48Le/4cSrr+Lvr76K1q++ggQCEJH+GZToMvi0K0Jd+Nvf0HH27LVfIRCAe88eNB8+jKS77oJ15kxEJyXxqRgpE9KRT3FxMaZMmYL4+HikpKRg5syZqK+vD1rT1taGwsJCJCUlIS4uDvn5+XC5XEFrGhoakJeXh8GDByMlJQXLli1DZ2dn7+8NXZW/pQVNGzfi+IoVOPvpp/C3tvIoiJQIKT6VlZUoLCzEnj17UF5ejo6ODsyYMQOtra3amqeffhobN25EaWkpKisrcerUKcyaNUvb7/f7kZeXh/b2duzevRvr1q3D2rVrsWLFivDdqxuALioK6OlRiwjavv0WDWvW4Ph//ieaDx1C4GpP34jCTCe9+N/emTNnkJKSgsrKStx5553weDwYOnQo3nvvPTz00EMAgGPHjmHcuHFwOByYOnUqtmzZgvvuuw+nTp2C1WoFAKxZswa//e1vcebMGRiNxqv+Xq/XC4vFAo/HA7PZ3NPxI1pnczPqn3sObd+fXu8NfUwMEqZOhfWBBxAzYgRPzVOPhfLY7NV/ZR6PBwCQmJgIAKiurkZHRwdycnK0NWPHjkV6ejocDgcAwOFwYMKECVp4ACA3Nxderxd1dXXd/h6fzwev1xt0udFFxcXhZ3PnIio+vte3Fbh4Eee3b8eXy5fD+de/osPt5lMx6nM9jk8gEMCSJUtwxx13YPz48QAAp9MJo9GIhISEoLVWqxVOp1Nb88PwdO3v2ted4uJiWCwW7ZKWltbTsa8bOp0OlilTMPzRR6GLjg7LbfpbWnDqf/4HXy5fjvOVlQi0tzNC1Gd6HJ/CwkIcPnwYJSUl4ZynW0VFRfB4PNqlsbGxz39nJNBFRSFx2jSkPfYYTKmp4blREbQ1NOCbN9/E16tWofXYMYjfzwhR2PXoVPvChQuxadMm7Ny5E8OHD9e222w2tLe3w+12Bx39uFwu2Gw2bc3evXuDbq/rbFjXmh8zmUwwmUw9GfW6pzcYkHzPPTBnZcG1YQPO79gBf3Nzr29XOjrg2bsXLXV1GHLnnRj20EOITk7mqXkKm5COfEQECxcuRFlZGbZt24aMjIyg/VlZWYiOjkZFRYW2rb6+Hg0NDbDb7QAAu92O2tpaNDU1aWvKy8thNpuRmZnZm/tyw9LpdDClpCDt0Ucx+r/+C4m/+AX0gwaF5bb9ra04u3Urvly+HE0ffYTOlhYeBVFYhHS266mnnsJ7772HDRs2YMyYMdp2i8WCmJgYAMCCBQuwefNmrF27FmazGYsWLQIA7N69G8ClU+0333wzUlNTsWrVKjidTsybNw+PPfYYXn755Wuag2e7Lk9ELh21HDiApvXr0frll5BwvYdKr0dcZiZs+fmInzAB+ms4M0k3llAemyHF53KH3O+++y5+85vfALj0JsNnnnkG77//Pnw+H3Jzc/H2228HPaX65ptvsGDBAuzYsQOxsbEoKCjAypUrYTBc27NAxufqRASBixfx3e7dcJWVoS2Mr5PpjEZYsrIwbPZsxGRk8KkYafosPgMF43PtRAQdZ8/izJYtOPvZZ+h0u8N229GJiUjOzUVybi6ihwxhhIjxoWBd/4rbvv0Wrg0b8N2uXQhcuBC22x90000Y9i//goTsbOiMRkboBsb40GUFOjvRUlcH5//+L1qOHLn0DYhhoDMYED9hAqyzZiH+5z+/9G2KjNANh/Ghq/JfvAj33r1oWr8eF77+GgjTfwZRgwcjafp0DL3vPpisVn5U4wbD+NA1ERF0ejw4s3Urzn76aWhf0XEV0cnJsD74IJJ+8QsYwvAREIoMjA+FRAIB+JxOnPn4Y5ytqAjf60F6PWJHj4b1wQdhufVW6PlG0ese40MhExEgEEDLkSNwlpWh+Ysvrv4tiddIFxWFhNtvh+2hhxBz002ATsfXg65TjA/1WNebFL0HDsD517+i9fhxIBAIy20bLBYk3XUXhublwTh0KAN0HWJ8qNdEBP6WFpz97DOc+fhjtJ85E7YXpQcNHw7bQw8hwW6HftAgRug6wvhQ2IgI2puacGbLFpyrqEDn99/h1Fs6gwGxY8ci9V//FbFjx0J/je9up4GN8aGwE78frV99BdeHH8J74AACPl9YbjcqNhZD/vmfkXLffTDZbHyTYoRjfKhPiAiksxPNhw7B+eGHaDl8OHzvD4qNRdzPf46UvDzET5x46TuqKeIwPtSnRAT+Cxdwfvt2NG3aBN+pU2G77ajBg5H68MMYmpvLAEWgfvsOZ7ox6XQ6GGJjMTQvD6NffhnD5sxBVGxsyLcjIvjO58P+s2dx3OtF4Puonfzv/4bnwAF+b9B1jq/yUY/pdDoYExMxbPZsWKZMQdOGDXBXVXX/1zf/iIigobUVzx88iHqPB7EGAx4bPRqzMzKACxfgKitD/PjxiPr+e6Lo+sMjH+o1nV6P2JEjcdO//Rv+aflymG+9FbqrnL0SAP+/thZH3G74ReDt6MAfjx7F4e++AwBc+Ppr+C9e7IfpSRUe+VDY6KOjET9hAmJHjsR3u3bBtXHjpS8xu8ybFL0/egd1eyAAX5g+ZU8DH498KKx0Ot2lT7bffTdGv/QSfjZvHqKTk3+6DsD/s9lg+MFp9dFmM26KiwMARJlM/ET8dY5HPtQndDodDGbzpQ+VTpkC14YNcO/eDf/3f7W2TqdDwciRiI+OxmenT2NYTAweHz0aKd9/8f2QadNg4JnM6xpPtVOf6/rQauvx43CWlsKzf7/2/iARgeDSkRBwKUpxmZnI+I//gPH7vwmXIkcoj00e+VCf0+l0QFQUYseMQca//zvcVVVwrV8P3+nTCLS1Qfd9iHTR0TDfcguGz5/P8NwAGB/qNzqdDlExMUicNg2WrCx0nDuH5sOH4XM6oY+JQdy4cYgbN46n128QjA/1O51OB0N8PAzx8YgZMUL1OKQITycQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpEVJ8Vq9ejYkTJ8JsNsNsNsNut2PLli3a/ra2NhQWFiIpKQlxcXHIz8+Hy+UKuo2Ghgbk5eVh8ODBSElJwbJly9DZ2Rmee0NEESOk+AwfPhwrV65EdXU19u/fj7vuugsPPPAA6urqAABPP/00Nm7ciNLSUlRWVuLUqVOYNWuWdn2/34+8vDy0t7dj9+7dWLduHdauXYsVK1aE914R0cAnvTRkyBB55513xO12S3R0tJSWlmr7jh49KgDE4XCIiMjmzZtFr9eL0+nU1qxevVrMZrP4fL7L/o62tjbxeDzapbGxUQCIx+Pp7fhEFEYej+eaH5s9fs3H7/ejpKQEra2tsNvtqK6uRkdHB3JycrQ1Y8eORXp6OhwOBwDA4XBgwoQJsFqt2prc3Fx4vV7t6Kk7xcXFsFgs2iUtLa2nYxPRABFyfGpraxEXFweTyYQnn3wSZWVlyMzMhNPphNFoREJCQtB6q9UKp9MJAHA6nUHh6drfte9yioqK4PF4tEtjY2OoYxPRABPydziPGTMGNTU18Hg8+OCDD1BQUIDKysq+mE1jMplgMpn69HcQUf8KOT5GoxEjR44EAGRlZWHfvn14/fXXMXv2bLS3t8Ptdgcd/bhcLthsNgCAzWbD3r17g26v62xY1xoiujH0+n0+gUAAPp8PWVlZiI6ORkVFhbavvr4eDQ0NsNvtAAC73Y7a2lo0NTVpa8rLy2E2m5GZmdnbUYgogoR05FNUVIR7770X6enpaG5uxnvvvYcdO3bgk08+gcViwfz587F06VIkJibCbDZj0aJFsNvtmDp1KgBgxowZyMzMxLx587Bq1So4nU4sX74chYWFfFpFdIMJKT5NTU14+OGHcfr0aVgsFkycOBGffPIJ7r77bgDAa6+9Br1ej/z8fPh8PuTm5uLtt9/Wrh8VFYVNmzZhwYIFsNvtiI2NRUFBAV588cXw3isiGvD4d7UTUdiE8tjkZ7uISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISIlexWflypXQ6XRYsmSJtq2trQ2FhYVISkpCXFwc8vPz4XK5gq7X0NCAvLw8DB48GCkpKVi2bBk6Ozt7MwoRRZgex2ffvn3405/+hIkTJwZtf/rpp7Fx40aUlpaisrISp06dwqxZs7T9fr8feXl5aG9vx+7du7Fu3TqsXbsWK1as6Pm9IKLIIz3Q3Nwso0aNkvLycpk2bZosXrxYRETcbrdER0dLaWmptvbo0aMCQBwOh4iIbN68WfR6vTidTm3N6tWrxWw2i8/n6/b3tbW1icfj0S6NjY0CQDweT0/GJ6I+4vF4rvmx2aMjn8LCQuTl5SEnJydoe3V1NTo6OoK2jx07Funp6XA4HAAAh8OBCRMmwGq1amtyc3Ph9XpRV1fX7e8rLi6GxWLRLmlpaT0Zm4gGkJDjU1JSggMHDqC4uPgn+5xOJ4xGIxISEoK2W61WOJ1Obc0Pw9O1v2tfd4qKiuDxeLRLY2NjqGMT0QBjCGVxY2MjFi9ejPLycgwaNKivZvoJk8kEk8nUb7+PiPpeSEc+1dXVaGpqwq233gqDwQCDwYDKykq88cYbMBgMsFqtaG9vh9vtDrqey+WCzWYDANhstp+c/er6uWsNEV3/QorP9OnTUVtbi5qaGu0yefJkzJ07V/vn6OhoVFRUaNepr69HQ0MD7HY7AMBut6O2thZNTU3amvLycpjNZmRmZobpbhHRQBfS0674+HiMHz8+aFtsbCySkpK07fPnz8fSpUuRmJgIs9mMRYsWwW63Y+rUqQCAGTNmIDMzE/PmzcOqVavgdDqxfPlyFBYW8qkV0Q0kpPhci9deew16vR75+fnw+XzIzc3F22+/re2PiorCpk2bsGDBAtjtdsTGxqKgoAAvvvhiuEchogFMJyKieohQeb1eWCwWeDwemM1m1eMQ0fdCeWzys11EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpIRB9QA9ISIAAK/Xq3gSIvqhrsdk12P0SiIyPufOnQMApKWlKZ6EiLrT3NwMi8VyxTURGZ/ExEQAQENDw1Xv4EDj9XqRlpaGxsZGmM1m1eNcM87dvyJ1bhFBc3MzUlNTr7o2IuOj1196qcpisUTUv5gfMpvNETk75+5fkTj3tR4Q8AVnIlKC8SEiJSIyPiaTCS+88AJMJpPqUUIWqbNz7v4VqXOHQifXck6MiCjMIvLIh4giH+NDREowPkSkBONDREowPkSkRETG56233sKIESMwaNAgZGdnY+/evUrn2blzJ+6//36kpqZCp9Nh/fr1QftFBCtWrMCwYcMQExODnJwcHD9+PGjN+fPnMXfuXJjNZiQkJGD+/PloaWnp07mLi4sxZcoUxMfHIyUlBTNnzkR9fX3Qmra2NhQWFiIpKQlxcXHIz8+Hy+UKWtPQ0IC8vDwMHjwYKSkpWLZsGTo7O/ts7tWrV2PixInau3/tdju2bNkyoGfuzsqVK6HT6bBkyZKImz0sJMKUlJSI0WiUP//5z1JXVyePP/64JCQkiMvlUjbT5s2b5Xe/+518+OGHAkDKysqC9q9cuVIsFousX79evvjiC/nVr34lGRkZcvHiRW3NPffcI5MmTZI9e/bI559/LiNHjpQ5c+b06dy5ubny7rvvyuHDh6WmpkZ++ctfSnp6urS0tGhrnnzySUlLS5OKigrZv3+/TJ06VW6//XZtf2dnp4wfP15ycnLk4MGDsnnzZklOTpaioqI+m/ujjz6Sjz/+WL788kupr6+X5557TqKjo+Xw4cMDduYf27t3r4wYMUImTpwoixcv1rZHwuzhEnHxue2226SwsFD72e/3S2pqqhQXFyuc6h9+HJ9AICA2m01eeeUVbZvb7RaTySTvv/++iIgcOXJEAMi+ffu0NVu2bBGdTicnT57st9mbmpoEgFRWVmpzRkdHS2lpqbbm6NGjAkAcDoeIXAqvXq8Xp9OprVm9erWYzWbx+Xz9NvuQIUPknXfeiYiZm5ubZdSoUVJeXi7Tpk3T4hMJs4dTRD3tam9vR3V1NXJycrRter0eOTk5cDgcCie7vBMnTsDpdAbNbLFYkJ2drc3scDiQkJCAyZMna2tycnKg1+tRVVXVb7N6PB4A//jWgOrqanR0dATNPnbsWKSnpwfNPmHCBFitVm1Nbm4uvF4v6urq+nxmv9+PkpIStLa2wm63R8TMhYWFyMvLC5oRiIw/73CKqE+1nz17Fn6/P+gPHgCsViuOHTumaKorczqdANDtzF37nE4nUlJSgvYbDAYkJiZqa/paIBDAkiVLcMcdd2D8+PHaXEajEQkJCVecvbv71rWvr9TW1sJut6OtrQ1xcXEoKytDZmYmampqBuzMAFBSUoIDBw5g3759P9k3kP+8+0JExYf6TmFhIQ4fPoxdu3apHuWajBkzBjU1NfB4PPjggw9QUFCAyspK1WNdUWNjIxYvXozy8nIMGjRI9TjKRdTTruTkZERFRf3k1X+XywWbzaZoqivrmutKM9tsNjQ1NQXt7+zsxPnz5/vlfi1cuBCbNm3C9u3bMXz4cG27zWZDe3s73G73FWfv7r517esrRqMRI0eORFZWFoqLizFp0iS8/vrrA3rm6upqNDU14dZbb4XBYIDBYEBlZSXeeOMNGAwGWK3WATt7X4io+BiNRmRlZaGiokLbFggEUFFRAbvdrnCyy8vIyIDNZgua2ev1oqqqSpvZbrfD7XajurpaW7Nt2zYEAgFkZ2f32WwigoULF6KsrAzbtm1DRkZG0P6srCxER0cHzV5fX4+Ghoag2Wtra4PiWV5eDrPZjMzMzD6b/ccCgQB8Pt+Annn69Omora1FTU2Ndpk8eTLmzp2r/fNAnb1PqH7FO1QlJSViMplk7dq1cuTIEXniiSckISEh6NX//tbc3CwHDx6UgwcPCgB59dVX5eDBg/LNN9+IyKVT7QkJCbJhwwY5dOiQPPDAA92ear/lllukqqpKdu3aJaNGjerzU+0LFiwQi8UiO3bskNOnT2uXCxcuaGuefPJJSU9Pl23btsn+/fvFbreL3W7X9ned+p0xY4bU1NTI1q1bZejQoX166vfZZ5+VyspKOXHihBw6dEieffZZ0el08umnnw7YmS/nh2e7Im323oq4+IiIvPnmm5Keni5Go1Fuu+022bNnj9J5tm/fLgB+cikoKBCRS6fbn3/+ebFarWIymWT69OlSX18fdBvnzp2TOXPmSFxcnJjNZnnkkUekubm5T+fubmYA8u6772prLl68KE899ZQMGTJEBg8eLA8++KCcPn066Hb+/ve/y7333isxMTGSnJwszzzzjHR0dPTZ3I8++qjcdNNNYjQaZejQoTJ9+nQtPAN15sv5cXwiafbe4vf5EJESEfWaDxFdPxgfIlKC8SEiJRgfIlKC8SEiJRgfIlKC8SEiJRgfIlKC8SEiJRgfIlKC8SEiJf4PnZ0nJjhV1koAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(\n",
    "            [action * 2])\n",
    "        over = terminated or truncated\n",
    "\n",
    "        #偏移reward,便于训练\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #限制最大步数\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    #打印游戏图像\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/cuda117/lib/python3.10/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "49.54885690099563"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "#玩一局游戏并记录数据\n",
    "def play(show=False):\n",
    "    data = []\n",
    "    reward_sum = 0\n",
    "\n",
    "    state = env.reset()\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据概率采样\n",
    "        mu, sigma = cql.model_action(torch.FloatTensor(state).reshape(1, 3))\n",
    "        action = random.normalvariate(mu=mu.item(), sigma=sigma.item())\n",
    "\n",
    "        next_state, reward, over = env.step(action)\n",
    "\n",
    "        data.append((state, action, reward, next_state, over))\n",
    "        reward_sum += reward\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    return data, reward_sum\n",
    "\n",
    "\n",
    "play()[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 2.21518611907959 9.759054348477896\n",
      "10 14.770238876342773 57.34436458273679\n",
      "20 22.05959129333496 35.228655545379\n",
      "30 28.099075317382812 89.85440499866809\n",
      "40 36.08259582519531 114.27782861153207\n",
      "50 44.10224151611328 157.9843883786122\n",
      "60 46.83131790161133 171.1354479880918\n",
      "70 50.894508361816406 181.79026115481287\n",
      "80 53.198081970214844 176.78876272448744\n",
      "90 54.12747573852539 183.08075296407867\n",
      "100 55.96063995361328 179.21220952184632\n",
      "110 58.04907989501953 184.37389171490022\n",
      "120 58.42814254760742 177.3067626818911\n",
      "130 58.266326904296875 182.22527272364547\n",
      "140 60.356361389160156 173.62140559932726\n",
      "150 58.26789093017578 183.66146342147925\n",
      "160 60.373321533203125 173.5510082565807\n",
      "170 59.34394836425781 178.38124251421726\n",
      "180 59.59429931640625 178.6704264063767\n",
      "190 60.33028030395508 174.081362401392\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    for epoch in range(200):\n",
    "        for i, (state, action, reward, next_state, over) in enumerate(loader):\n",
    "            cql.train_value(state, action, reward, next_state, over)\n",
    "            cql.train_action(state)\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            value = cql.model_value(torch.cat([state, action],\n",
    "                                              dim=1)).mean().item()\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, value, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAe6klEQVR4nO3de3CTZd438O+d5tBjUltsYqUFdmCBPpwLlOjs6kqkul1X1u6uy/Jo12XwFQMDdh9m6S7iu+yhvDijq66iM67CP9h9YESFBdy+BYu+hIOVSinQBUVbgaRITdIDTdLkev+A3g/hZAJJr4R+PzP3DL2vX5rfDebrnes+KUIIASKiAaaR3QARDU4MHyKSguFDRFIwfIhICoYPEUnB8CEiKRg+RCQFw4eIpGD4EJEUDB8ikkJa+Lz88ssYPnw4UlNTUVJSgn379slqhYgkkBI+//jHP1BZWYlnnnkGn3zyCSZOnIjS0lK0t7fLaIeIJFBkXFhaUlKCadOm4W9/+xsAIBQKoaCgAIsWLcKyZcu+9fWhUAinTp1CVlYWFEWJd7tEFCEhBDo7O5Gfnw+N5tr7NtoB6knl9/vR0NCAqqoqdZ1Go4HNZoPD4bjia3w+H3w+n/rzyZMnUVRUFPdeiej6tLW1YejQodesGfDw+frrrxEMBmE2m8PWm81mHD169Iqvqa6uxh/+8IfL1re1tcFoNMalTyKKntfrRUFBAbKysr61dsDD53pUVVWhsrJS/bl/A41GI8OHKAFFMh0y4OEzZMgQpKSkwOVyha13uVywWCxXfI3BYIDBYBiI9ohogAz40S69Xo/i4mLU1dWp60KhEOrq6mC1Wge6HSKSRMrXrsrKSlRUVGDq1KmYPn06/vrXv6K7uxuPPfaYjHaISAIp4fPwww/jzJkzWLFiBZxOJyZNmoTt27dfNglNRDcvKef53Civ1wuTyQSPx8MJZ6IEEs1nk9d2EZEUDB8ikoLhQ0RSMHyISAqGDxFJwfAhIikYPkQkBcOHiKRg+BCRFAwfIpKC4UNEUjB8iEgKhg8RScHwISIpGD5EJAXDh4ikYPgQkRQMHyKSguFDRFIwfIhICoYPEUnB8CEiKRg+RCQFw4eIpGD4EJEUDB8ikoLhQ0RSMHyISAqGDxFJwfAhIikYPkQkBcOHiKRg+BCRFAwfIpKC4UNEUkQdPrt27cIDDzyA/Px8KIqCd955J2xcCIEVK1bgtttuQ1paGmw2G44dOxZW09HRgblz58JoNCI7Oxvz5s1DV1fXDW0IESWXqMOnu7sbEydOxMsvv3zF8dWrV+PFF1/Eq6++ir179yIjIwOlpaXo7e1Va+bOnYvm5mbU1tZiy5Yt2LVrFx5//PHr3woiSj7iBgAQmzZtUn8OhULCYrGIZ599Vl3ndruFwWAQb731lhBCiMOHDwsAYv/+/WrNtm3bhKIo4uTJkxG9r8fjEQCEx+O5kfaJKMai+WzGdM7nxIkTcDqdsNls6jqTyYSSkhI4HA4AgMPhQHZ2NqZOnarW2Gw2aDQa7N2794q/1+fzwev1hi1ElNxiGj5OpxMAYDabw9abzWZ1zOl0Ii8vL2xcq9UiJydHrblUdXU1TCaTuhQUFMSybSKSICmOdlVVVcHj8ahLW1ub7JaI6AbFNHwsFgsAwOVyha13uVzqmMViQXt7e9h4X18fOjo61JpLGQwGGI3GsIWIkltMw2fEiBGwWCyoq6tT13m9XuzduxdWqxUAYLVa4Xa70dDQoNbs2LEDoVAIJSUlsWyHiBKYNtoXdHV14fjx4+rPJ06cQGNjI3JyclBYWIglS5bgT3/6E0aNGoURI0bg6aefRn5+PmbPng0AGDt2LO677z7Mnz8fr776KgKBABYuXIhf/OIXyM/Pj9mGEVGCi/ZQ2s6dOwWAy5aKigohxPnD7U8//bQwm83CYDCImTNnipaWlrDfcfbsWTFnzhyRmZkpjEajeOyxx0RnZ2fEPfBQO1FiiuazqQghhMTsuy5erxcmkwkej4fzP0QJJJrPZlIc7SKimw/Dh4ikYPgQkRRRH+0iisTVphIVRRngTihRMXwo5kQwiO7jx/HNrl3oOXECikaD9FGjkPP97yNt2DAoGu5wE8OHYiwUCKB9yxY4N25EsLNTXd958CA6du5E/n/+J3J/8AMoKSkSu6REwP8FUcyIUAhnd+7E6bfeCguefoGODnz1xhtw799/1a9lNHgwfChmAh0dcP73fyN00Y3jLhXs6sLp9esR7O4ewM4oETF8KGY6m5vhv+Si4Ss598UXOPf55wPQESUyhg/FhBACCIUirg94PHHshpIBw4diQwh07NoVcbn7wp0tafBi+FDsRDGJ3NfVxUnnQY7hQ7GhKMgcNy7i8j6vF8Genjg2RImO4UMxo8/Jibi298sv4b/KPbtpcGD4UEwoigK92QxNampE9SIUQigQiHNXlMgYPhQzqUOHIiUzM7JiIdDz2WfxbYgSGsOHYiYlNRUanS7i+q7DhznpPIgxfChmFK0WGaNHR1wf8vujOjeIbi4MH4odjQaphYURl5/78kv08emzgxbDh2JGURRo09OBCO/ZEzh7lofbBzGGD8VUxtixSElPj6xYCPTxMotBi+FDMaXLzYWijew2UaKvD95PP41zR5SoGD4UUxqtFtooHmfU53bziNcgxfChmNIYDMgsKoq43t/RAcGTDQclhg/FlkYT1Z5P99GjnHQepBg+FFOKosCQnw9EeI/mUCCAkM8X564oETF8KOYyRo+GJtJJ50AAvV99FeeOKBExfCjmUlJTgQgfjyMCAfQcP85J50GI4UMxl5KejrThwyOu7+vqiupGZHRzYPhQzGlSU5Ganx9xfc9nn/GI1yDE8KGYUzSaiO/rAwC+U6cQ6uuLY0eUiBg+FBfGKVMin/fp6+NlFoMQw4fiwmA2Q4nwAtNgdzdvLDYIMXwoLjQ6HTRpaRHVimCQl1kMQgwfigvdkCFIHzUq4nqf08kjXoNMVOFTXV2NadOmISsrC3l5eZg9ezZaWlrCanp7e2G325Gbm4vMzEyUl5fD5XKF1bS2tqKsrAzp6enIy8vD0qVL0ccJx5uKotWev7dPhDqbmyGCwTh2RIkmqvCpr6+H3W7Hnj17UFtbi0AggFmzZqG7u1uteeqpp7B582Zs2LAB9fX1OHXqFB566CF1PBgMoqysDH6/H7t378a6deuwdu1arFixInZbRdIpioKMKC4wDfb08HD7IKOIG/iifebMGeTl5aG+vh7f//734fF4cOutt2L9+vX46U9/CgA4evQoxo4dC4fDgRkzZmDbtm340Y9+hFOnTsFsNgMAXn31Vfz2t7/FmTNnoNfrv/V9vV4vTCYTPB4PjFFcxEgDy3vwII4tXx5RrSY1FSOfeQZZ//Efce6K4imaz+YNzfl4LhwezbnwsLiGhgYEAgHYbDa1ZsyYMSgsLITjwrO5HQ4Hxo8frwYPAJSWlsLr9aK5ufmK7+Pz+eD1esMWSnwpGRkR31gs1NuLQEcHJ50HkesOn1AohCVLluDOO+/EuAuPyXU6ndDr9cjOzg6rNZvNcF54OqXT6QwLnv7x/rErqa6uhslkUpeCgoLrbZsGkCEv7/wV7hHizeQHl+sOH7vdjkOHDqGmpiaW/VxRVVUVPB6PurS1tcX9PenGpWRkQJuVFXG9Z9++OHZDiSayfeJLLFy4EFu2bMGuXbswdOhQdb3FYoHf74fb7Q7b+3G5XLBYLGrNvkv+I+s/GtZfcymDwQCDwXA9rZJMigJdFM9v7/N6zz/HK8J7AVFyi2rPRwiBhQsXYtOmTdixYwdGjBgRNl5cXAydToe6ujp1XUtLC1pbW2G1WgEAVqsVTU1NaG9vV2tqa2thNBpRFMXREUoOxokTI64N9vSgr7Mzjt1QIolqz8dut2P9+vV49913kZWVpc7RmEwmpKWlwWQyYd68eaisrEROTg6MRiMWLVoEq9WKGTNmAABmzZqFoqIiPPLII1i9ejWcTieWL18Ou93OvZubUDR7Pj6nE71tbdBdMmdIN6eowmfNmjUAgLvvvjts/Ztvvolf/epXAIDnn38eGo0G5eXl8Pl8KC0txSuvvKLWpqSkYMuWLViwYAGsVisyMjJQUVGBlStX3tiWUMJRFAW63FykZGYi2NX17S8QAsHeXgghIr4ujJLXDZ3nIwvP80kewXPncPS//gu9ER4kyJ87F5af/5zhk6QG7Dwfom+jMRigieDE0X6dVznXi24+DB+KL0WJ6jlewu+H4HV+gwLDh+Iumvs5+06fRuDs2fg1QwmD4UNxpSgKUtLSgAjncAJuNw+3DxIMH4q79JEjob3G4fOgEDjZ04OTPT3oDATg7+gYuOZImus6w5koGtrsbGiucQ5XIBTC/3O54Dx3DueCQXznpZfwc50OEyZMQEpKCo983aQYPhR3SkoKtEYj/Fe5cNig0eBnF+aFfKEQnGlp2LhhA7Zv34758+djyJAhDKCbEL92UdwpKSkwTpp09XFFUZfUlBSMy83FimXLMH78ePz5z3/GyZMneauNmxDDh+JPUaCN4mTQnhMngJ4elJWV4eGHH8Zzzz2Hb775Jo4NkgwMH4o7RVHOP0pHp4uoXgQCCJ07B41Gg5KSEnzve9/DunXrEAqF4twpDSSGDw2I9JEjzx9yj4AIhXCutRUAoNFocP/99+P06dM4duxYPFukAcbwobgSQkAIcf5oV4RPMEUohO6WFnWeJzU1FWVlZdi+fTvnfm4iDB+KGyEEGhsb0dPTA41ej4xLnuMlhMA3Ph8+/vprHPN6EbooWILd3edvLHbB5MmT8cUXX4Q9KYWSG8OH4sbv9+Mvf/kLmpuboeh0SC0sVMeEEGjt7sbifftg37MH/2v3btScOIHghQDq+ewzBHt61PrMzEykpaVx4vkmwvChuDl27Bh27dqFt99++/xXr4uubhcA/k9TEw673QgKAW8ggL8dOYJDF8LFf/YsQhc9x0tRFBiNRpw7d26gN4PihOFDcXP8+HGMGTMGPp/v/P1dJk2CclEAeS95SKA/FILvGk8t1el0nPO5iTB8KG5sNhvuvfdePProo8jKykLa8OEwTp4MAFAA/MBigfaiM5e/azRiWGYmACDFYIByyQR1V1cXtBE+B4wSH/8lKW4yMjJQUFAAl8uFyZMnA1otbp87F13NzQh2daFi5Ehk6XT4v6dP47a0NMz/7neRl5oKALjlrrvCTkwMBoPwer3IiuJRPJTYuOdDcTVp0iQcOHBAPUEwtbAQt82ZA0Wvh/bCNV2vWq3435Mm4fb0dCgXbj5mnj07bM/n5MmTMBgM6tNxKfkxfChuFEXB6NGjcebMGfVJJ4pGg1vvuw+3V1RAl5MDRaOB5sJ1XRq9Hqbp0zFs8WLoLwoZIQRqa2thtVr5tesmwn9JiiuDwQCbzYaNGzfiySefhFarhUanQ15ZGUxTpsB74AB8Tic0aWnIHDsWmWPHXnYm9FdffYWDBw/ij3/8o6StoHhg+FBcKYqCe+65Bw6HAx9++CHuvvvu81ewazRIvf12pN5++zVf39nZiddeew0/+9nP+KSSmwy/dlHcGQwGPPnkk9i8eTM++uijiC4QFULA7XbjpZdewujRo3HHHXfwnj43GYYPxZ2iKLBYLKisrMR7772Hv//97+jo6LjqOTvBYBAHDx7EypUr8Z3vfAdz5syBJtLrwihp8KGBNGCEEOju7sbbb7+NPXv2YPTo0Zg2bRry8vKg1WrR1dWFzz//HLt37wYA/PKXv8S4cePUG41R4ovms8nwoQEnhMA333yDxsZGHDhwAGfPnkUoFEJaWhqGDRuGkpISjBw5ElqtlqGTZBg+RCQFH5dMRAmP4UNEUjB8iEgKhg8RScHwISIpGD5EJAXDh4ikYPgQkRQMHyKSIqrwWbNmDSZMmACj0Qij0Qir1Ypt27ap4729vbDb7cjNzUVmZibKy8vhcrnCfkdrayvKysqQnp6OvLw8LF26FH19fbHZGiJKGlGFz9ChQ7Fq1So0NDTg448/xj333IMHH3wQzc3NAICnnnoKmzdvxoYNG1BfX49Tp07hoYceUl8fDAZRVlYGv9+P3bt3Y926dVi7di1WrFgR260iosQnbtAtt9wiXn/9deF2u4VOpxMbNmxQx44cOSIACIfDIYQQYuvWrUKj0Qin06nWrFmzRhiNRuHz+a76Hr29vcLj8ahLW1ubACA8Hs+Ntk9EMeTxeCL+bF73nE8wGERNTQ26u7thtVrR0NCAQCAAm82m1owZMwaFhYVwOBwAAIfDgfHjx8NsNqs1paWl8Hq96t7TlVRXV8NkMqlLQUHB9bZNRAki6vBpampCZmYmDAYDnnjiCWzatAlFRUVwOp3Q6/XIzs4OqzebzerNw51OZ1jw9I/3j11NVVUVPB6PurS1tUXbNhElmKjv4Tx69Gg0NjbC4/Fg48aNqKioQH19fTx6UxkMBhgMhri+BxENrKjDR6/XY+TIkQCA4uJi7N+/Hy+88AIefvhh+P1+uN3usL0fl8sFi8UCALBYLNi3b1/Y7+s/GtZfQ0SDww2f5xMKheDz+VBcXAydToe6ujp1rKWlBa2trbBarQAAq9WKpqYmtLe3qzW1tbUwGo0oKiq60VaIKIlEtedTVVWF+++/H4WFhejs7MT69evxwQcf4P3334fJZMK8efNQWVmJnJwcGI1GLFq0CFarFTNmzAAAzJo1C0VFRXjkkUewevVqOJ1OLF++HHa7nV+riAaZqMKnvb0djz76KE6fPg2TyYQJEybg/fffx7333gsAeP7556HRaFBeXg6fz4fS0lK88sor6utTUlKwZcsWLFiwAFarFRkZGaioqMDKlStju1VElPB4D2ciihnew5mIEh7Dh4ikYPgQkRQMHyKSguFDRFIwfIhICoYPEUnB8CEiKRg+RCQFw4eIpGD4EJEUDB8ikoLhQ0RSMHyISAqGDxFJwfAhIikYPkQkBcOHiKRg+BCRFAwfIpKC4UNEUjB8iEgKhg8RScHwISIpGD5EJAXDh4ikYPgQkRQMHyKSguFDRFIwfIhICoYPEUnB8CEiKRg+RCQFw4eIpGD4EJEUNxQ+q1atgqIoWLJkibqut7cXdrsdubm5yMzMRHl5OVwuV9jrWltbUVZWhvT0dOTl5WHp0qXo6+u7kVaIKMlcd/js378fr732GiZMmBC2/qmnnsLmzZuxYcMG1NfX49SpU3jooYfU8WAwiLKyMvj9fuzevRvr1q3D2rVrsWLFiuvfCiJKPuI6dHZ2ilGjRona2lpx1113icWLFwshhHC73UKn04kNGzaotUeOHBEAhMPhEEIIsXXrVqHRaITT6VRr1qxZI4xGo/D5fFd8v97eXuHxeNSlra1NABAej+d62ieiOPF4PBF/Nq9rz8dut6OsrAw2my1sfUNDAwKBQNj6MWPGoLCwEA6HAwDgcDgwfvx4mM1mtaa0tBRerxfNzc1XfL/q6mqYTCZ1KSgouJ62iSiBRB0+NTU1+OSTT1BdXX3ZmNPphF6vR3Z2dth6s9kMp9Op1lwcPP3j/WNXUlVVBY/Hoy5tbW3Rtk1ECUYbTXFbWxsWL16M2tpapKamxqunyxgMBhgMhgF7PyKKv6j2fBoaGtDe3o4pU6ZAq9VCq9Wivr4eL774IrRaLcxmM/x+P9xud9jrXC4XLBYLAMBisVx29Kv/5/4aIrr5RRU+M2fORFNTExobG9Vl6tSpmDt3rvpnnU6Huro69TUtLS1obW2F1WoFAFitVjQ1NaG9vV2tqa2thdFoRFFRUYw2i4gSXVRfu7KysjBu3LiwdRkZGcjNzVXXz5s3D5WVlcjJyYHRaMSiRYtgtVoxY8YMAMCsWbNQVFSERx55BKtXr4bT6cTy5ctht9v51YpoEIkqfCLx/PPPQ6PRoLy8HD6fD6WlpXjllVfU8ZSUFGzZsgULFiyA1WpFRkYGKioqsHLlyli3QkQJTBFCCNlNRMvr9cJkMsHj8cBoNMpuh4guiOazyWu7iEgKhg8RScHwISIpGD5EJAXDh4ikYPgQkRQMHyKSguFDRFIwfIhICoYPEUnB8CEiKRg+RCQFw4eIpGD4EJEUDB8ikoLhQ0RSMHyISAqGDxFJwfAhIikYPkQkBcOHiKRg+BCRFAwfIpKC4UNEUjB8iEgKhg8RScHwISIpGD5EJAXDh4ikYPgQkRQMHyKSguFDRFIwfIhICoYPEUnB8CEiKRg+RCQFw4eIpNDKbuB6CCEAAF6vV3InRHSx/s9k/2f0WpIyfM6ePQsAKCgokNwJEV1JZ2cnTCbTNWuSMnxycnIAAK2trd+6gYnG6/WioKAAbW1tMBqNstuJGPseWMnatxACnZ2dyM/P/9bapAwfjeb8VJXJZEqqf5iLGY3GpOydfQ+sZOw70h0CTjgTkRQMHyKSIinDx2Aw4JlnnoHBYJDdStSStXf2PbCSte9oKCKSY2JERDGWlHs+RJT8GD5EJAXDh4ikYPgQkRQMHyKSIinD5+WXX8bw4cORmpqKkpIS7Nu3T2o/u3btwgMPPID8/HwoioJ33nknbFwIgRUrVuC2225DWloabDYbjh07FlbT0dGBuXPnwmg0Ijs7G/PmzUNXV1dc+66ursa0adOQlZWFvLw8zJ49Gy0tLWE1vb29sNvtyM3NRWZmJsrLy+FyucJqWltbUVZWhvT0dOTl5WHp0qXo6+uLW99r1qzBhAkT1LN/rVYrtm3bltA9X8mqVaugKAqWLFmSdL3HhEgyNTU1Qq/XizfeeEM0NzeL+fPni+zsbOFyuaT1tHXrVvH73/9evP322wKA2LRpU9j4qlWrhMlkEu+884749NNPxY9//GMxYsQIce7cObXmvvvuExMnThR79uwRH374oRg5cqSYM2dOXPsuLS0Vb775pjh06JBobGwUP/zhD0VhYaHo6upSa5544glRUFAg6urqxMcffyxmzJgh7rjjDnW8r69PjBs3TthsNnHgwAGxdetWMWTIEFFVVRW3vt977z3xz3/+U/z73/8WLS0t4ne/+53Q6XTi0KFDCdvzpfbt2yeGDx8uJkyYIBYvXqyuT4beYyXpwmf69OnCbrerPweDQZGfny+qq6sldvU/Lg2fUCgkLBaLePbZZ9V1brdbGAwG8dZbbwkhhDh8+LAAIPbv36/WbNu2TSiKIk6ePDlgvbe3twsAor6+Xu1Tp9OJDRs2qDVHjhwRAITD4RBCnA9ejUYjnE6nWrNmzRphNBqFz+cbsN5vueUW8frrrydFz52dnWLUqFGitrZW3HXXXWr4JEPvsZRUX7v8fj8aGhpgs9nUdRqNBjabDQ6HQ2JnV3fixAk4nc6wnk0mE0pKStSeHQ4HsrOzMXXqVLXGZrNBo9Fg7969A9arx+MB8D93DWhoaEAgEAjrfcyYMSgsLAzrffz48TCbzWpNaWkpvF4vmpub495zMBhETU0Nuru7YbVak6Jnu92OsrKysB6B5Pj7jqWkuqr966+/RjAYDPuLBwCz2YyjR49K6uranE4nAFyx5/4xp9OJvLy8sHGtVoucnBy1Jt5CoRCWLFmCO++8E+PGjVP70uv1yM7OvmbvV9q2/rF4aWpqgtVqRW9vLzIzM7Fp0yYUFRWhsbExYXsGgJqaGnzyySfYv3//ZWOJ/PcdD0kVPhQ/drsdhw4dwkcffSS7lYiMHj0ajY2N8Hg82LhxIyoqKlBfXy+7rWtqa2vD4sWLUVtbi9TUVNntSJdUX7uGDBmClJSUy2b/XS4XLBaLpK6urb+va/VssVjQ3t4eNt7X14eOjo4B2a6FCxdiy5Yt2LlzJ4YOHaqut1gs8Pv9cLvd1+z9StvWPxYver0eI0eORHFxMaqrqzFx4kS88MILCd1zQ0MD2tvbMWXKFGi1Wmi1WtTX1+PFF1+EVquF2WxO2N7jIanCR6/Xo7i4GHV1deq6UCiEuro6WK1WiZ1d3YgRI2CxWMJ69nq92Lt3r9qz1WqF2+1GQ0ODWrNjxw6EQiGUlJTErTchBBYuXIhNmzZhx44dGDFiRNh4cXExdDpdWO8tLS1obW0N672pqSksPGtra2E0GlFUVBS33i8VCoXg8/kSuueZM2eiqakJjY2N6jJ16lTMnTtX/XOi9h4Xsme8o1VTUyMMBoNYu3atOHz4sHj88cdFdnZ22Oz/QOvs7BQHDhwQBw4cEADEc889Jw4cOCC+/PJLIcT5Q+3Z2dni3XffFQcPHhQPPvjgFQ+1T548Wezdu1d89NFHYtSoUXE/1L5gwQJhMpnEBx98IE6fPq0uPT09as0TTzwhCgsLxY4dO8THH38srFarsFqt6nj/od9Zs2aJxsZGsX37dnHrrbfG9dDvsmXLRH19vThx4oQ4ePCgWLZsmVAURfzrX/9K2J6v5uKjXcnW+41KuvARQoiXXnpJFBYWCr1eL6ZPny727NkjtZ+dO3cKAJctFRUVQojzh9uffvppYTabhcFgEDNnzhQtLS1hv+Ps2bNizpw5IjMzUxiNRvHYY4+Jzs7OuPZ9pZ4BiDfffFOtOXfunHjyySfFLbfcItLT08VPfvITcfr06bDf88UXX4j7779fpKWliSFDhojf/OY3IhAIxK3vX//612LYsGFCr9eLW2+9VcycOVMNnkTt+WouDZ9k6v1G8X4+RCRFUs35ENHNg+FDRFIwfIhICoYPEUnB8CEiKRg+RCQFw4eIpGD4EJEUDB8ikoLhQ0RSMHyISIr/D7sMxsKwI65RAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "169.81928665726804"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:cuda117]",
   "language": "python",
   "name": "conda-env-cuda117-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
